-- Copyright (C) Yichun Zhang (agentzh)


local sub = string.sub
local byte = string.byte
-- NOTE(review): this indexes the `ngx` global unconditionally, while
-- _M.new() guards with `ngx and ...`; the LuaSocket path therefore still
-- requires an `ngx` table to exist at load time.
local null = ngx.null
local type = type
local pairs = pairs
local ipairs = ipairs
local select = select
-- `unpack` is a global only in Lua 5.1/LuaJIT; Lua 5.2+ moved it to
-- `table.unpack`, which matters when running on the LuaSocket backend.
local unpack = unpack or table.unpack
local setmetatable = setmetatable
local tonumber = tonumber
local tostring = tostring
local rawget = rawget
--local error = error


-- Prefer LuaJIT's table.new (pre-sized allocation) when it can be loaded;
-- otherwise use a plain constructor that simply ignores the size hints.
local loaded, new_tab = pcall(require, "table.new")
if not (loaded and type(new_tab) == "function") then
	new_tab = function ()
		return {}
	end
end


-- Module table; 54 is passed to new_tab as the hash-part size hint
-- (presumably sized for the methods defined in this file).
local _M = new_tab(0, 54)
_M._VERSION = "0.26"


-- Commands whose methods are generated eagerly when the module loads; any
-- other command gets a method lazily via _M's metatable on first lookup.
-- "hmset" is intentionally absent: it has a dedicated implementation that
-- also accepts a Lua table of field/value pairs.
local common_cmds = {
	-- Strings
	"get", "set", "mget", "mset", "del", "incr", "decr",
	-- Lists
	"llen", "lindex", "lpop", "lpush", "lrange", "linsert",
	-- Hashes
	"hexists", "hget", "hset", "hmget", "hdel",
	-- Sets
	"smembers", "sismember", "sadd", "srem", "sdiff", "sinter", "sunion",
	-- Sorted Sets
	"zrange", "zrangebyscore", "zrank", "zadd", "zrem", "zincrby",
	-- Others
	"auth", "eval", "expire", "script", "sort",
}


-- Commands that put the connection into subscribed state.
local sub_commands = { "subscribe", "psubscribe" }


-- Commands whose reply may take the connection out of subscribed state.
local unsub_commands = { "unsubscribe", "punsubscribe" }


-- Metatable for client objects: method lookups fall through to _M.
local mt = { __index = _M }


--- Create a new, not yet connected, client object.
-- An ngx cosocket is used when `ngx.can_socket()` returns true; otherwise a
-- LuaSocket TCP object from "socket.core" is used. The choice is recorded in
-- the `_usngx` field.
-- NOTE(review): `ngx.can_socket` does not appear to be part of the stock
-- ngx_lua API; presumably a project-provided helper — confirm.
-- @return the client object, or nil, err if the socket cannot be created
function _M.new(self)
	local use_ngx = false
	local make_tcp
	if ngx and ngx.can_socket() then
		use_ngx = true
		make_tcp = ngx.socket.tcp
	else
		make_tcp = require("socket.core").tcp
	end

	local sock, err = make_tcp()
	if not sock then
		return nil, err
	end

	local client = {
		_sock = sock,
		_subscribed = false,
		_usngx = use_ngx,
	}
	return setmetatable(client, mt)
end


--- Set the timeout used by subsequent socket operations.
-- @param timeout timeout in milliseconds, following the ngx cosocket API.
--   LuaSocket's settimeout() expects seconds, so a numeric value is
--   converted when the LuaSocket backend is in use.
-- @return the underlying settimeout() results, or nil, "not initialized"
function _M.set_timeout(self, timeout)
	local sock = rawget(self, "_sock")
	if not sock then
		return nil, "not initialized"
	end

	-- ngx: milliseconds; LuaSocket: seconds. Without this conversion the
	-- same caller value yields a 1000x longer timeout on LuaSocket.
	if rawget(self, "_usngx") == false and type(timeout) == "number" then
		timeout = timeout / 1000
	end

	return sock:settimeout(timeout)
end


--- Connect the underlying socket, forwarding all arguments unchanged.
-- The subscribed flag is reset because a (re)connected socket starts with
-- no subscriptions.
-- @return the underlying connect() results, or nil, "not initialized"
function _M.connect(self, ...)
	local sock = rawget(self, "_sock")
	if sock then
		self._subscribed = false
		return sock:connect(...)
	end

	return nil, "not initialized"
end


--- Return the connection to the ngx cosocket pool.
-- Refused while subscribed. On the LuaSocket backend there is no pool, and
-- its sockets have no setkeepalive method, so the connection is closed
-- instead (releasing it is the closest equivalent).
-- @return the underlying setkeepalive()/close() results, or nil, err
function _M.set_keepalive(self, ...)
	local sock = rawget(self, "_sock")
	if not sock then
		return nil, "not initialized"
	end

	if rawget(self, "_subscribed") then
		return nil, "subscribed state"
	end

	-- LuaSocket: calling sock:setkeepalive() would raise
	-- "attempt to call method 'setkeepalive' (a nil value)".
	if rawget(self, "_usngx") == false then
		return sock:close()
	end

	return sock:setkeepalive(...)
end


--- Report how many times the underlying connection was taken from the pool.
-- LuaSocket connections are never pooled, and its sockets have no
-- getreusedtimes method, so 0 is reported on that backend.
-- @return reuse count, or nil, "not initialized"
function _M.get_reused_times(self)
	local sock = rawget(self, "_sock")
	if not sock then
		return nil, "not initialized"
	end

	if rawget(self, "_usngx") == false then
		return 0
	end

	return sock:getreusedtimes()
end


--- Close the underlying socket.
-- Also kept as a local so other functions in this file can call it directly.
-- @return the underlying close() results, or nil, "not initialized"
local function close(self)
	local sock = rawget(self, "_sock")
	if sock then
		return sock:close()
	end

	return nil, "not initialized"
end
_M.close = close


--- Read and parse one complete RESP reply from `sock`.
-- Reply types are mapped as follows:
--   '+' status  -> the status text (prefix stripped)
--   '-' error   -> false, the error text
--   ':' integer -> a number
--   '$' bulk    -> the payload string, or `null` when the length is negative
--   '*' array   -> a table of parsed elements, or `null` when the count is
--                  negative; an error element is stored as {false, msg}
-- A socket failure or malformed reply yields nil, err. On a read timeout the
-- socket is closed because the stream position is no longer known, except
-- when a header line times out while the client is subscribed.
-- @param self client object (only `_subscribed` is consulted)
-- @param sock the connected socket to read from
local function _read_reply(self, sock)
	local line, err = sock:receive()
	if not line then
		-- NOTE(review): presumably a timeout in subscribed mode just means no
		-- message was pushed yet, so the connection is kept for reuse.
		if err == "timeout" and not rawget(self, "_subscribed") then
			sock:close()
		end
		return nil, err
	end

	local prefix = byte(line)

	if prefix == 36 then    -- char '$'
		local size = tonumber(sub(line, 2))
		if not size then
			-- report instead of raising "attempt to compare nil with number"
			return nil, "bad bulk string length: \"" .. line .. "\""
		end

		if size < 0 then
			return null
		end

		local data, err = sock:receive(size)
		if not data then
			if err == "timeout" then
				sock:close()
			end
			return nil, err
		end

		-- consume the CRLF terminating the payload; a timeout here also
		-- leaves the stream mid-reply, so close just like the payload read
		local dummy, err = sock:receive(2)
		if not dummy then
			if err == "timeout" then
				sock:close()
			end
			return nil, err
		end

		return data

	elseif prefix == 43 then    -- char '+'
		return sub(line, 2)

	elseif prefix == 42 then    -- char '*'
		local n = tonumber(sub(line, 2))
		if not n then
			return nil, "bad multi-bulk length: \"" .. line .. "\""
		end

		if n < 0 then
			return null
		end

		local vals = new_tab(n, 0)
		local nvals = 0
		for _ = 1, n do
			local res, err = _read_reply(self, sock)
			if res then
				nvals = nvals + 1
				vals[nvals] = res

			elseif res == nil then
				return nil, err

			else
				-- nested error reply (res == false): keep it in place as a
				-- {false, msg} pair so element positions are preserved
				nvals = nvals + 1
				vals[nvals] = {false, err}
			end
		end

		return vals

	elseif prefix == 58 then    -- char ':'
		return tonumber(sub(line, 2))

	elseif prefix == 45 then    -- char '-'
		return false, sub(line, 2)

	else
		-- when `line` is an empty string, `prefix` will be equal to nil.
		return nil, "unknown prefix: \"" .. tostring(prefix) .. "\""
	end
end


--- Encode a command as a RESP array of bulk strings.
-- The result is a flat array of fragments (strings plus the numeric
-- lengths), not a single string: ngx cosockets can send such a table
-- directly, which avoids building the whole request string in the Lua VM.
-- @param args array holding the command name followed by its arguments;
--   non-string values are converted with tostring()
-- @return the fragment array
local function _gen_req(args)
	local count = #args
	local parts = new_tab(count * 5 + 1, 0)
	parts[1] = "*" .. count .. "\r\n"

	local pos = 1
	for i = 1, count do
		local value = args[i]
		if type(value) ~= "string" then
			value = tostring(value)
		end

		parts[pos + 1] = "$"
		parts[pos + 2] = #value
		parts[pos + 3] = "\r\n"
		parts[pos + 4] = value
		parts[pos + 5] = "\r\n"
		pos = pos + 5
	end

	return parts
end

--- Flatten a request fragment array (as built by _gen_req) into one string.
-- Used for the LuaSocket backend, whose send() only accepts strings.
-- @param data array of string/number fragments (no holes)
-- @return the concatenation of all fragments
local function _tostring(data)
	-- table.concat builds the result in a single pass; appending with `..`
	-- in a loop re-copies the growing string on every step (quadratic).
	return table.concat(data)
end

--- Serialize one command and either queue it (pipeline mode) or send it and
-- read back its reply.
-- @param ... command name followed by its arguments
-- @return the parsed reply (see _read_reply); nothing when the request was
--   queued by an active pipeline; nil, err on failure
local function _do_cmd(self, ...)
	local args = {...}

	local sock = rawget(self, "_sock")
	if not sock then
		return nil, "not initialized"
	end

	local req = _gen_req(args)
	-- LuaSocket's send() only takes strings, while ngx cosockets accept the
	-- fragment table as-is. Read the flag with rawget like every other
	-- instance field, so the lookup never falls through to _M's lazy
	-- command-method __index.
	if rawget(self, "_usngx") == false then
		req = _tostring(req)
	end

	local reqs = rawget(self, "_reqs")
	if reqs then
		reqs[#reqs + 1] = req
		return
	end

	-- print("request: ", table.concat(req))

	local bytes, err = sock:send(req)
	if not bytes then
		return nil, err
	end

	return _read_reply(self, sock)
end


--- Clear the subscribed flag when `res` is an (p)unsubscribe reply whose
-- third element (the remaining subscription count) is 0.
-- Any other value of `res`, including nil, leaves the flag untouched.
local function _check_subscribed(self, res)
	if type(res) ~= "table" then
		return
	end

	local kind = res[1]
	local is_unsub = kind == "unsubscribe" or kind == "punsubscribe"
	if is_unsub and res[3] == 0 then
		self._subscribed = false
	end
end


--- Read the next reply while the client is in subscribed state.
-- Also clears the subscribed flag if the reply is a final (p)unsubscribe.
-- @return the parsed reply (see _read_reply), or nil, err
function _M.read_reply(self)
	local sock = rawget(self, "_sock")
	if sock == nil or sock == false then
		return nil, "not initialized"
	end

	local subscribed = rawget(self, "_subscribed")
	if subscribed == nil or subscribed == false then
		return nil, "not subscribed"
	end

	local reply, reply_err = _read_reply(self, sock)
	_check_subscribed(self, reply)
	return reply, reply_err
end


-- Eagerly define a method for each common command; each one just forwards
-- the command name and arguments to _do_cmd.
for _, cmd in ipairs(common_cmds) do
	_M[cmd] = function (self, ...)
		return _do_cmd(self, cmd, ...)
	end
end


-- (p)subscribe mark the client as subscribed *before* sending the request.
for _, cmd in ipairs(sub_commands) do
	_M[cmd] = function (self, ...)
		self._subscribed = true
		return _do_cmd(self, cmd, ...)
	end
end


-- (p)unsubscribe pass their reply through _check_subscribed, which clears
-- the subscribed flag once the reported subscription count reaches 0.
for _, cmd in ipairs(unsub_commands) do
	_M[cmd] = function (self, ...)
		local res, err = _do_cmd(self, cmd, ...)
		_check_subscribed(self, res)
		return res, err
	end
end


--- HMSET with two accepted calling conventions:
--   red:hmset(hash, { field1 = value1, field2 = value2 })
--   red:hmset(hash, field1, value1, field2, value2, ...)  -- legacy form
-- With exactly one extra argument, it is treated as a field => value table
-- and flattened into alternating field/value arguments.
function _M.hmset(self, hashname, ...)
	if select('#', ...) ~= 1 then
		-- backwards compatibility
		return _do_cmd(self, "hmset", hashname, ...)
	end

	local fields = ...

	-- first pass only sizes the flattened array
	local size = 0
	for _ in pairs(fields) do
		size = size + 2
	end

	local flat = new_tab(size, 0)
	local pos = 0
	for field, value in pairs(fields) do
		flat[pos + 1] = field
		flat[pos + 2] = value
		pos = pos + 2
	end

	return _do_cmd(self, "hmset", hashname, unpack(flat))
end


--- Start queuing commands instead of sending them immediately.
-- @param n optional size hint for the request queue (defaults to 4)
function _M.init_pipeline(self, n)
	local capacity = n or 4
	self._reqs = new_tab(capacity, 0)
end


--- Drop every command queued since init_pipeline() without sending it,
-- and leave pipeline mode.
function _M.cancel_pipeline(self)
	local no_queue = nil
	self._reqs = no_queue
end


--- Send every queued command in one write and collect all replies.
-- Leaves pipeline mode even when sending fails.
-- @return array of replies in request order (error replies are stored as
--   {false, msg}), or nil, err on failure or when no pipeline is active
function _M.commit_pipeline(self)
	local reqs = rawget(self, "_reqs")
	if not reqs then
		return nil, "no pipeline"
	end

	self._reqs = nil

	local sock = rawget(self, "_sock")
	if not sock then
		return nil, "not initialized"
	end

	-- ngx cosockets accept the (nested) request table directly. On the
	-- LuaSocket backend every queued request is already a string (see
	-- _do_cmd), but send() rejects tables, so join them first.
	local payload = reqs
	if rawget(self, "_usngx") == false then
		payload = table.concat(reqs)
	end

	local bytes, err = sock:send(payload)
	if not bytes then
		return nil, err
	end

	local nvals = 0
	local nreqs = #reqs
	local vals = new_tab(nreqs, 0)
	for _ = 1, nreqs do
		local res, err = _read_reply(self, sock)
		if res then
			nvals = nvals + 1
			vals[nvals] = res

		elseif res == nil then
			if err == "timeout" then
				close(self)
			end
			return nil, err

		else
			-- be a valid redis error value
			nvals = nvals + 1
			vals[nvals] = {false, err}
		end
	end

	return vals
end


--- Convert a flat {k1, v1, k2, v2, ...} array into {k1 = v1, k2 = v2, ...}.
-- With an odd-length array the last key maps to nil and is therefore absent.
function _M.array_to_hash(self, t)
	local len = #t
	local hash = new_tab(0, len / 2)
	local i = 1
	while i <= len do
		hash[t[i]] = t[i + 1]
		i = i + 2
	end
	return hash
end


-- Deprecated: methods for unknown commands are already generated lazily by
-- _M's metatable, so explicit registration is no longer needed.
-- Declared with '.', so call it as `redis.add_commands("cmd1", "cmd2")`.
function _M.add_commands(...)
	local names = {...}
	for idx = 1, #names do
		local name = names[idx]
		_M[name] = function (self, ...)
			return _do_cmd(self, name, ...)
		end
	end
end


-- Lazily create a command method for any key missing from _M on its first
-- lookup, then cache it in _M so later lookups are plain table hits.
-- NOTE: this applies to every missing key, not only Redis command names.
setmetatable(_M, {
	__index = function (_, cmd)
		local function method(self, ...)
			return _do_cmd(self, cmd, ...)
		end

		_M[cmd] = method
		return method
	end,
})


return _M
